package com.xunmall.gateway.interceptor;

import com.xunmall.base.exception.BadRequestException;
import com.xunmall.base.exception.ForbiddenException;
import com.xunmall.base.util.WebUtils;
import com.xunmall.security.GlobalSessionUtil;
import com.xunmall.security.SpringSecurityUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * 检查url上的公司uuid是否与session中一致，避免越权访问
 * 检查url上的职员uuid是否与session中一致，避免越权访问
 */
public class UrlCheckFilter implements Filter {
    private static final String companyUuidRegex = "/v1/co/(\\w+)/.*";
    private static final String emplUuidRegex = "/v1/em/(\\w+)/.*";

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // left blank intentionally
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletResponse httpResponse = (HttpServletResponse) response;
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;

        if (!HttpMethod.OPTIONS.name().equalsIgnoreCase(httpServletRequest.getMethod())) {
            String url = httpServletRequest.getRequestURI();

            // 针对/v1/co/开头的url进行越权检查
            if (StringUtils.isNotEmpty(url) && url.startsWith("/v1/co/")) {
                Boolean multiCompany = GlobalSessionUtil.getMultiCompany();
                String sessionCompanyUuid = GlobalSessionUtil.getCurCompUuid();
                //如果是多公司 并且未选择公司 则跳转到公司首页
                if (StringUtils.isEmpty(sessionCompanyUuid)) {
                    if (multiCompany) {
                        BadRequestException bex = new BadRequestException("error.gateway.0005", this);
                        WebUtils.sendJsonError(httpServletRequest, httpResponse, bex, HttpStatus.BAD_REQUEST);
                        return;
                    } else {
                        BadRequestException bex = new BadRequestException("error.gateway.0010", this);
                        WebUtils.sendJsonError(httpServletRequest, httpResponse, bex, HttpStatus.BAD_REQUEST);
                        return;
                    }
                }

                String urlCompanyUuid = getCompanyUuid(url);
                if(StringUtils.isEmpty(urlCompanyUuid)) { // url中没有公司uuid
                    BadRequestException bex = new BadRequestException("error.gateway.0050", this);
                    WebUtils.sendJsonError(httpServletRequest, httpResponse, bex, HttpStatus.BAD_REQUEST);
                    return;
                }
                if (SpringSecurityUtils.getUserId() != null) { // 已登录
                    if (StringUtils.isEmpty(sessionCompanyUuid)) { // session中没有compyUuid
                        BadRequestException bex = new BadRequestException("error.gateway.0052", this);
                        WebUtils.sendJsonError(httpServletRequest, httpResponse, bex, HttpStatus.BAD_REQUEST);
                        return;
                    } else {
                        if (!urlCompanyUuid.equals(sessionCompanyUuid)) { // 越权访问
                            ForbiddenException fex = new ForbiddenException("error.gateway.0051", this);
                            WebUtils.sendJsonError(httpServletRequest, httpResponse, fex, HttpStatus.FORBIDDEN);
                            return;
                        }
                    }
                } else { // 未登录
                    ForbiddenException fex = new ForbiddenException("error.0011", this);
                    WebUtils.sendJsonError(httpServletRequest, httpResponse, fex, HttpStatus.FORBIDDEN);
                    return;
                }
            }
            // 针对/v1/em/开头的url进行越权检查
            if (StringUtils.isNotEmpty(url) && url.startsWith("/v1/em/")) {
                String urlEmplUuid = getEmplUuid(url);
                if (StringUtils.isEmpty(urlEmplUuid)) { // 没有职员uuid
                    BadRequestException bex = new BadRequestException("error.gateway.0056", this);
                    WebUtils.sendJsonError(httpServletRequest, httpResponse, bex, HttpStatus.BAD_REQUEST);
                    return;
                }
                if (SpringSecurityUtils.getUserId() != null) { // 已登录
                    String seesionEmplUuid = GlobalSessionUtil.getEmplUuid();
                    if (StringUtils.isEmpty(seesionEmplUuid)) { // session中没有emplUuid
                        BadRequestException bex = new BadRequestException("error.gateway.0058", this);
                        WebUtils.sendJsonError(httpServletRequest, httpResponse, bex, HttpStatus.BAD_REQUEST);
                        return;
                    } else {
                        if (!urlEmplUuid.equals(seesionEmplUuid)) { // 越权访问
                            ForbiddenException fex = new ForbiddenException("error.gateway.0057", this);
                            WebUtils.sendJsonError(httpServletRequest, httpResponse, fex, HttpStatus.FORBIDDEN);
                            return;
                        }
                    }
                } else { // 未登录
                    ForbiddenException fex = new ForbiddenException("error.0011", this);
                    WebUtils.sendJsonError(httpServletRequest, httpResponse, fex, HttpStatus.FORBIDDEN);
                    return;
                }
            }
        }

        chain.doFilter(request, response);
    }
    private String getCompanyUuid(String url) {
        Pattern pattern = Pattern.compile(companyUuidRegex);
        Matcher matcher = pattern.matcher(url);
        if(matcher.find()) {
            return matcher.group(1);
        } else {
            return null;
        }
    }
    private String getEmplUuid(String url) {
        Pattern pattern = Pattern.compile(emplUuidRegex);
        Matcher matcher = pattern.matcher(url);
        if(matcher.find()) {
            return matcher.group(1);
        } else {
            return null;
        }
    }

    @Override
    public void destroy() {
        // left blank intentionally
    }
}
